using LinearAlgebra
using JLD
using Optim, LineSearches
using StatsBase
using ProgressMeter
using Kronecker

## Given a matrix, return the leverage scores

"""
    leverages(A)

Return the leverage scores of the rows of `A`: entry `i` is the squared Euclidean norm of
row `i` of the economic (thin) `Q` factor of the QR factorization of `A`.

NOTE(review): this matches the usual statistical leverage only when `A` has full column
rank — an unpivoted QR is used and no rank check is performed.
"""
function leverages(A)
    Q = Matrix(qr(A).Q)  # economic Q factor: orthonormal columns spanning range(A)
    return [sum(abs2, q_row) for q_row in eachrow(Q)]
end

## Given a matrix, return (approximate) Lewis weights after 10 fixed-point iterations.
# 4 iterations were enough for Vandermonde matrices, but do not quite suffice here.

"""
    lewis_weights(A, p=1, mode=1; iters=10)

Approximate the ℓp Lewis weights of the rows of `A` by running `iters` rounds of the
fixed-point update

    w ← (w^(2/p - 1) .* leverages(Diagonal(w^(1/2 - 1/p)) * A))^(p/2)

starting from `w = ones(size(A, 1))`. For `p == 2` the Lewis weights are exactly the
leverage scores, so those are returned directly and `iters` is ignored.

# Arguments
- `A`: matrix whose rows are weighted.
- `p`: the ℓp exponent.
- `mode`: unused; kept only so existing positional calls keep working.
- `iters`: number of fixed-point iterations (default 10, the previous hard-coded value).

# Throws
`ArgumentError` if `iters` is negative.

NOTE(review): this plain fixed-point iteration is normally used for p < 4 (callers here
pass q ∈ [2, 4]); convergence for other `p` is not checked — confirm before relying on it.
"""
function lewis_weights(A, p=1, mode=1; iters::Integer=10)
    iters >= 0 || throw(ArgumentError("iters must be nonnegative, got $iters"))

    if p == 2
        return leverages(A)
    end

    w = ones(size(A, 1))
    for _ in 1:iters
        # Reweight rows by w^(1/2 - 1/p); the leverage of row i of the reweighted matrix,
        # divided by w_i^(1 - 2/p), gives the quadratic form used by the Lewis-weight update.
        W = Diagonal(w .^ (1/2 - 1/p))
        levs = leverages(W * A)
        w = (w .^ (2/p - 1) .* levs) .^ (p/2)
    end
    return w
end


## Subsampling

"""
    subsample(A, b, probs, p, s)

Draw `s` rows of `A` (and the matching entries of `b`) i.i.d. with replacement, choosing
row `i` with probability `π_i = probs[i] / sum(probs)`. Each drawn row is rescaled by
`(s * π_i)^(-1/p)`, so that `norm(SA * x - Sb, p)^p` is an unbiased estimate of
`norm(A * x - b, p)^p` for every `x`.

# Arguments
- `A`: `n × d` matrix.
- `b`: length-`n` vector.
- `probs`: length-`n` vector of nonnegative sampling weights (need not sum to 1).
- `p`: the ℓp exponent used in the rescaling.
- `s`: number of rows to draw.

# Returns
`(SA, Sb)`: the `s × d` rescaled sampled rows of `A` and the length-`s` rescaled entries of `b`.

# Throws
`DimensionMismatch` if `b` or `probs` does not have exactly one entry per row of `A`.
"""
function subsample(A, b, probs, p, s)
    n = size(A, 1)
    length(b) == n ||
        throw(DimensionMismatch("b has length $(length(b)), expected $n"))
    length(probs) == n ||
        throw(DimensionMismatch("probs has length $(length(probs)), expected $n"))

    probs = probs / sum(probs)  # normalize into a probability distribution (no mutation)
    samples = sample(1:n, Weights(probs), s, replace=true)  # row ids to keep

    # Only the s sampled rows need rescaling. Rescaling all n rows first and selecting
    # afterwards cost O(n*d) work; this is O(s*d) and produces identical values.
    scale = (probs[samples] .* s) .^ (-1/p)
    SA = A[samples, :] .* scale
    Sb = b[samples] .* scale
    return SA, Sb
end

## Kronecker expansion

"""
    expand_matrix(A, folds)

Return the `n × d^folds` matrix whose `i`-th row is the `folds`-fold Kronecker (tensor)
power of the `i`-th row of the `n × d` matrix `A`.

NOTE(review): an earlier comment identified `folds` with r in Algorithm 3 of the paper,
but `lp_statistical_complexity` passes `2^r` — confirm which convention the paper uses.
"""
function expand_matrix(A, folds)
    rows, cols = size(A)
    expanded = zeros(rows, cols^folds)
    for (k, arow) in enumerate(eachrow(A))
        rowmat = reshape(collect(arow), 1, cols)  # 1 × d matrix holding row k
        expanded[k, :] = kronecker(rowmat, folds)  # tensor power of that row
    end
    return expanded
end

## Lp Regression

"""
    LpRegression(A, b, p)

Approximately solve `min_x norm(A * x - b, p)` with L-BFGS (backtracking line search),
using forward-mode automatic differentiation for gradients and starting from `x = 0`.
Returns the minimizer reported by the optimizer; convergence is not checked.
"""
function LpRegression(A, b, p)
    ncols = size(A, 2)
    objective = x -> norm(A * x - b, p)
    x0 = zeros(ncols)
    method = LBFGS(linesearch = LineSearches.BackTracking())
    result = optimize(objective, x0, method; autodiff = :forward)
    return result.minimizer
end


## Uniform Sampling for LpRegression (i.e. alternative to Algorithm 3 in the paper)

"""
    LpRegression_uniform(A, b, p, s)

Baseline method: draw `s` rows of `(A, b)` uniformly at random via `subsample` (equal
weight for every row, rescaled for the ℓp objective) and return the ℓp regression
solution of the subsampled problem.
"""
function LpRegression_uniform(A, b, p, s)
    uniform_weights = fill(1.0, size(A, 1))
    A_sub, b_sub = subsample(A, b, uniform_weights, p, s)
    return LpRegression(A_sub, b_sub, p)
end

## Run the Lp approx vs. exact regression problem many times, and store the output in a JLD file
# The JLD file can then be opened to compute medians and quartiles for the LaTeX Figure Graphics

"""
    lp_statistical_complexity(; n=25000, d=10, n0=100, d0=6, p=6, n_trials=100,
                                resolution=8, filename=...)

Compare the relative ℓp regression error of Lewis-weight subsampling (the "fast" method,
using q-Lewis weights of a tensored copy of `A`) against uniform subsampling, for
`resolution` log-spaced sample sizes between 1 and 1000, repeated `n_trials` times.
The relative error is `|OPT - norm(A*x̂ - b, p)| / OPT`, where OPT is the residual norm
of the full-data solution. The error matrices (trials × sample sizes) are saved to
`filename` in JLD format.

The design matrix `A` is block diagonal: an `n0 × d0` Gaussian block in the top-left and
an `(n - n0) × (d - d0)` Gaussian block in the bottom-right. The first `d0` true
coefficients are scaled by 100, presumably so that the few rows of the small block
(rarely hit by uniform sampling) dominate the solution.

All keyword defaults equal the previously hard-coded values.

# Throws
`ArgumentError` if `p < 2` (the tensoring depth `r` below would be negative) or the block
sizes do not fit inside `A`.
"""
function lp_statistical_complexity(;
    n = 25000,
    d = 10,
    n0 = 100,
    d0 = 6,
    p = 6,
    n_trials = 100,
    resolution = 8,
    # NOTE(review): this default name says "50trials" although n_trials defaults to 100;
    # it is kept unchanged so downstream scripts that load this file still find it.
    filename = "Lp_regression_data_randn_25k_d10_p6_50trials_d06_n100_log.jld",
)
    p >= 2 || throw(ArgumentError("p must be at least 2, got $p"))
    (0 <= n0 <= n && 0 <= d0 <= d) ||
        throw(ArgumentError("need 0 <= n0 <= n and 0 <= d0 <= d, got n0=$n0, n=$n, d0=$d0, d=$d"))

    # Build the block-diagonal matrix A and the noisy response b
    A = [
        randn(n0,d0) zeros(n0,d-d0)
        zeros(n-n0,d0) randn(n-n0,d-d0)
    ]
    b = A*[100*randn(d0); 1*randn(d-d0)] + 1*randn(n)

    # True Lp regression solution. p must be defined before this call; previously it was
    # assigned after this line, which raised UndefVarError.
    true_soln = LpRegression(A, b, p)

    # Optimal objective value, hoisted because it is reused for every trial
    opt_err = norm(A * true_soln - b, p)
    rel_err(x) = abs(opt_err - norm(A * x - b, p)) / opt_err

    # Sample sizes m, log-spaced between 1 and 1000
    m_range = Int.(floor.(10 .^ (range(log10(1), log10(1000), length=resolution))))

    # Errors from our method ("fast_") and from uniform sampling ("unif_")
    fast_rel_errs = zeros(n_trials, length(m_range))
    unif_rel_errs = zeros(n_trials, length(m_range))

    # Tensor A 2^r times so that q = p / 2^r lies in [2, 4), then compute the exact
    # (iterated) q-Lewis weights of the tensored matrix; they are the sampling probabilities.
    r = Int(floor(log2(p) - 1))
    q = p/(2^r)
    M = expand_matrix(A, 2^r)
    lw = lewis_weights(M , q)
    @info "Finished Computing Lewis Weights"

    @showprogress for t in 1:n_trials
        for (i, m) in enumerate(m_range)
            # Sample with probabilities lw, but rescale with p: the subsampled problem is
            # solved in the p-norm, so only (s*π_i)^(-1/p) keeps it an unbiased estimate of
            # the full objective (and matches the uniform baseline, which also uses p).
            T, b_sampled = subsample(A, b, lw, p, m)
            fast_est_soln = LpRegression(T, b_sampled, p)

            unif_est_soln = LpRegression_uniform(A, b, p, m)

            fast_rel_errs[t, i] = rel_err(fast_est_soln)
            unif_rel_errs[t, i] = rel_err(unif_est_soln)
        end
    end

    # Save everything as a JLD file (roughly a Julia analogue of a JSON file)
    return JLD.save(filename,
             "fast_rel_errs", fast_rel_errs,
             "unif_rel_errs", unif_rel_errs,
             "d_base", d, "p_base", p, "s_range", m_range,
             "n", n, "n_trials", n_trials, "resolution", resolution)
end

##

# Script entry point: runs the full experiment (default settings) and writes the JLD
# results file whenever this line or cell is executed.
lp_statistical_complexity()
